1656H - Equal LCM Subsets - CodeForces Solution


Tags: data structures, math, number theory; rating *3200

Please click on ads to support us.

C++ Code:

#pragma GCC optimize("O2")

#pragma GCC optimize ("unroll-loops")

#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,avx2,fma")



#include <bits/stdc++.h>

// #include <bits/extc++.h> 

// #include <ext/pb_ds/assoc_container.hpp>

// #include <ext/pb_ds/tree_policy.hpp>



// Streams `val` in fixed notation with `prec` digits after the decimal point.
#define fix(val, prec) std::fixed << std::setprecision(prec) << val

using ll = long long;

// Four unit moves: step k is (dx[k], dy[k]) and direction[k] is its letter.
// Presumably (row, column) offsets, since dx = +1 is labelled 'D' (down).
int dx[4] = {1, 0, -1, 0};
int dy[4] = {0, -1, 0, 1};
char direction[4] = {'D', 'L', 'U', 'R'};

 

using namespace std;

// using namespace __gnu_pbds;

 

// template <typename T>

// using ordered_set = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;

 

// template <typename T>

// using ordered_multiset = tree<T, null_type, less_equal<T>, rb_tree_tag, tree_order_statistics_node_update>; 



// Debug tracing. When compiled with -DADD_TRACE, trace(x, y, ...) writes
// "Line:<n> " to cout and then calls __f with the stringified argument list
// and the argument values. __f comes from "utils/debug.h", which is not part
// of this file; presumably it prints each name/value pair (TODO confirm).
// In normal builds trace(...) expands to nothing, so its arguments are never
// evaluated.
#ifdef ADD_TRACE

#include "utils/debug.h"

#define trace(...) cout<<"Line:"<<__LINE__<<" ", __f(#__VA_ARGS__, __VA_ARGS__)

#else

#define trace(...)

#endif



// struct edge{

// 	int u, v, w;



// 	int other(int node) {

// 		return u ^ v ^ node; 

// 	}



// 	bool operator < (edge other) const {

// 		return w < other.w; 

// 	}

// };



// const int RANDOM = chrono::high_resolution_clock::now().time_since_epoch().count();

// struct chash { // To use most bits rather than just the lowest ones:

// 	const uint64_t C = ll(4e18 * acos(0)) | 71; // large odd number

// 	ll operator()(ll x) const { return __builtin_bswap64((x^RANDOM)*C); }

// };



// mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());



// ll getRand(ll l, ll r) {

// 	uniform_int_distribution<ll> uid(l, r); 

// 	return uid(rng); 

// }



// O(n^2 * log A + n * log A * (log log A) ^ 2) where A = max(S)

//size of primebasis is O(n * log log A) on average, worst case O(n * log A)



// Greatest common divisor by Euclid's algorithm, written as a loop.
// gcd(a, 0) == a. The only operations needed on T are `% ` and a comparison
// with 0, so the file's 128-bit bigint type can use it as well.
template<typename T>
T gcd(T a, T b) {
    while (!(b == 0)) {
        T rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}



// Incrementally maintained coprime basis: `basis` holds pairwise-coprime
// values > 1, and every value passed to add_element() can be written as a
// product of powers of basis elements. verify_basis() and verify_element()
// assert these two properties. T needs %, /, /=, >, == and the file's gcd().
template<typename T> 
struct PrimeBasis{
	// Current basis elements (pairwise coprime, each > 1).
	vector<T> basis;

	// Repeatedly divides the larger of (x, y) by the smaller while the smaller
	// is > 1 and divides it exactly. On return either the smaller value is
	// <= 1 or neither value divides the other. to_swap records the parity of
	// the internal swaps, so the results come back in the caller's order.
	void reduce_pair(T& x, T& y){
        bool to_swap = 0;
        if(x > y){
            to_swap ^= 1;
            swap(x, y);
        } 
        while(x > 1 && y % x == 0){
            y /= x;
            if(x > y){
                to_swap ^= 1;
                swap(x, y);
            }
        }
        if(to_swap) swap(x, y);
    }

	// Splits basis[pos] and val, which share a factor > 1, into
	// pairwise-coprime pieces.
	// - basis[pos] keeps the reduced part of the old basis element.
	// - val is updated in place to the reduced part of the incoming value.
	// - Every other piece > 1 is appended to basis.
	// - If basis[pos] ends up as 1 it is removed by swapping with the last
	//   element and popping.
	// val must be > 1: with val == 1 the first while loop never terminates.
	void solve_inner(int pos, T &val){
        while(basis[pos] % val == 0) basis[pos] /= val;
        vector<T> curr_basis = {basis[pos], val};
        int c_ptr = 1;
        // Each new candidate curr_basis[c_ptr] is refined against every earlier
        // piece. A common factor g > 1 is divided out of both and appended as
        // a new piece, which is itself refined on a later iteration.
        while(c_ptr < curr_basis.size()){
            for(int i=0;i<c_ptr;i++){
                reduce_pair(curr_basis[i], curr_basis[c_ptr]);
                if(curr_basis[c_ptr] == 1) break;
                if(curr_basis[i] == 1) continue;
                T g = gcd(curr_basis[c_ptr], curr_basis[i]);
                if(g > 1){
                    curr_basis[c_ptr] /= g;
                    curr_basis[i] /= g;
                    curr_basis.push_back(g);
                }
            }
            ++c_ptr;
        }
        basis[pos] = curr_basis[0];
        val = curr_basis[1];
        for(int i=2;i<curr_basis.size();i++) if(curr_basis[i] > 1) basis.push_back(curr_basis[i]);
        if(basis[pos] == 1){
            swap(basis[pos], basis.back());
            basis.pop_back();
        }
    }

	// Adds val to the basis, refining existing elements so the basis stays
	// pairwise coprime. Each element is first reduced against val.
	// - If the basis element drops to 1, it is removed (swap-with-last) and
	//   scanning stops; the remaining val is then appended.
	// - If val drops to 1, it is already fully represented and nothing is added.
	// - If the two still share a factor, solve_inner splits them.
	// NOTE(review): when solve_inner removes basis[i] via swap-with-last, the
	// element moved into slot i is not re-examined, because i still
	// increments. Confirm this cannot leave val sharing a factor with it.
	void add_element(T val){
        for(int i=0;i<basis.size();i++){
            reduce_pair(val, basis[i]);
            if(basis[i] == 1){
                swap(basis[i], basis.back());
                basis.pop_back();
				break; 
            }
            if(val == 1) return;
            if(gcd(basis[i], val) > 1) solve_inner(i, val);
        }
        if(val > 1) basis.push_back(val);
    }

	// Debug check: asserts that all basis elements are pairwise coprime
	// (O(k^2) gcd calls for k basis elements).
	void verify_basis(){
        for(int i=0;i<basis.size();i++){
            for(int j=i+1;j<basis.size();j++){
                assert(gcd(basis[i], basis[j]) == 1);
            }
        }
    }

	// Debug check: asserts that ele is a product of powers of basis elements.
	void verify_element(T ele){
        for(auto &x : basis){
            while(ele % x == 0) ele /= x;
        }
        assert(ele == 1);
    }

	// Returns the exponent of each basis element in ele, indexed like `basis`.
	// Any part of ele that the basis does not cover is silently ignored.
	vector<int> factorisation(T ele){
        vector<int> factors(basis.size());
        for(int i=0;i<basis.size();i++){
            while(ele % basis[i] == 0){
                factors[i]++;
                ele /= basis[i];
            }
        }
        return factors;
    }
};



// Portable unsigned 128-bit integer made of two 64-bit limbs, together with
// the arithmetic, shift, comparison and division operators this file needs.
// All arithmetic wraps modulo 2^128.
namespace Codeforces {

    // Lazily built lookup table giving the index of the most significant set
    // bit of any 16-bit value.
    namespace MostSignificantBit {
        const int SIZE = 1 << 16;

        // table[v] = index of the highest set bit of v; table[0] stays 0.
        uint8_t table[SIZE];

        // table[2] is 1 once Init() has filled the table, and 0 before that.
        bool WasInitialized() { return table[2]; }

        // Fills the table. Returns true so the call can be chained with ||.
        bool Init() {
            for (int i = 1; i < SIZE; ++i) {
                for (uint8_t pow = 0; pow < 16; ++pow) {
                    if ((i >> pow) & 1) {
                        table[i] = pow;
                    }
                }
            }
            return true;
        }

        // MSB index of the low 16 bits of `value` (0 for 0). Initialises the
        // table on first use.
        int Get(int value) { return WasInitialized() || Init(), table[value & (SIZE-1)]; }
    }

    // Unsigned 128-bit value equal to high * 2^64 + low.
    struct uint128_t {

        uint64_t high, low;

        uint128_t(uint64_t high_, uint64_t low_) : high(high_), low(low_) { }

        // Implicit on purpose: lets plain integers mix with uint128_t in
        // expressions such as `v * 10 + d`.
        uint128_t(uint64_t low_ = 0) : high(0), low(low_) { }
    };

    bool operator==(uint128_t lhs, uint128_t rhs) { return lhs.low == rhs.low && lhs.high == rhs.high; }
    bool operator<=(uint128_t lhs, uint128_t rhs) { return lhs.high < rhs.high || (lhs.high == rhs.high && lhs.low <= rhs.low); }
    bool operator>=(uint128_t lhs, uint128_t rhs) { return rhs <= lhs; }
    bool operator!=(uint128_t lhs, uint128_t rhs) { return !(lhs == rhs); }
    bool operator< (uint128_t lhs, uint128_t rhs) { return !(lhs >= rhs); }
    bool operator> (uint128_t lhs, uint128_t rhs) { return !(lhs <= rhs); }

    // Index of the most significant set bit. Each of these returns 0 for a
    // zero argument.
    int mostSignificantBit32(uint32_t value) {
        return (value >> 16) ? MostSignificantBit::Get(value >> 16) + 16 : MostSignificantBit::Get(value);
    }

    int mostSignificantBit64(uint64_t value) {
        return (value >> 32) ? mostSignificantBit32(uint32_t(value >> 32)) + 32 : mostSignificantBit32(value & ~uint32_t(0));
    }

    int mostSignificantBit(uint128_t value) {
        return value.high ? mostSignificantBit64(value.high) + 64 : mostSignificantBit64(value.low);
    }

    // Addition modulo 2^128. A carry out of the low limb shows up as unsigned
    // wrap-around (the sum is smaller than an addend).
    uint128_t operator+(uint128_t lhs, uint128_t rhs) {
        uint128_t ret(lhs.high + rhs.high, lhs.low + rhs.low);
        ret.high += (ret.low < lhs.low);
        return ret;
    }

    // Subtraction modulo 2^128. Borrows from the high limb when the low limb
    // wraps.
    uint128_t operator-(uint128_t lhs, uint128_t rhs) {
        uint128_t ret(lhs.high - rhs.high, lhs.low - rhs.low);
        ret.high -= (lhs.low < ret.low);
        return ret;
    }

    uint128_t& operator+=(uint128_t& lhs, uint128_t rhs) { return lhs = lhs + rhs; }
    uint128_t& operator-=(uint128_t& lhs, uint128_t rhs) { return lhs = lhs - rhs; }

    // Logical shifts; `cnt` must be non-negative. A count of 128 or more
    // moves every bit out and yields 0. This case is handled explicitly
    // because shifting a 64-bit limb by 64 or more is undefined behaviour.
    uint128_t operator<<(uint128_t lhs, int cnt) {
        if (cnt == 0) { return lhs; }
        if (cnt >= 128) { return uint128_t(0); }
        if (cnt >= 64) { return uint128_t(lhs.low << (cnt - 64), 0); }
        return uint128_t((lhs.high << cnt) | (lhs.low >> (64 - cnt)), lhs.low << cnt);
    }

    uint128_t operator>>(uint128_t lhs, int cnt) {
        if (cnt == 0) { return lhs; }
        if (cnt >= 128) { return uint128_t(0); }
        if (cnt >= 64) { return uint128_t(lhs.high >> (cnt - 64)); }
        return uint128_t(lhs.high >> cnt, (lhs.low >> cnt) | (lhs.high << (64 - cnt)));
    }

    uint128_t& operator>>=(uint128_t& lhs, int cnt) { return lhs = lhs >> cnt; }
    uint128_t& operator<<=(uint128_t& lhs, int cnt) { return lhs = lhs << cnt; }

    // Product modulo 2^128. The low limbs are multiplied schoolbook-style on
    // their 32-bit halves. The high limbs only contribute through their
    // products with the other low limb; high * high lies entirely above
    // 2^128 and is dropped.
    uint128_t operator*(uint128_t lhs, uint128_t rhs) {
        uint64_t a32 = lhs.low >> 32, a00 = lhs.low & 0xffffffff;
        uint64_t b32 = rhs.low >> 32, b00 = rhs.low & 0xffffffff;
        uint128_t ret(lhs.high * rhs.low + lhs.low * rhs.high + a32 * b32, a00 * b00);
        return ret + (uint128_t(a32 * b00) << 32) + (uint128_t(a00 * b32) << 32);
    }

    // Binary long division: sets q = a / b and r = a % b. b must be non-zero
    // (asserted).
    void DivMod(uint128_t a, uint128_t b, uint128_t &q, uint128_t &r) {
        assert(b.low | b.high);
        if (a < b) { q = 0, r = a; return; }
        if (a == b) { q = 1, r = 0; return; }
        // Align b's top bit with a's, then produce one quotient bit per step.
        const int shift = mostSignificantBit(a) - mostSignificantBit(b);
        q = 0, r = a, b <<= shift;
        for (int i = 0; i <= shift; ++i) {
            q <<= 1;
            if (r >= b) { r -= b; q.low |= 1; }
            b >>= 1;
        }
    }

    uint128_t operator/(uint128_t lhs, uint128_t rhs) {
        uint128_t div, rem;
        return DivMod(lhs, rhs, div, rem), div;
    }

    uint128_t operator%(uint128_t lhs, uint128_t rhs) {
        uint128_t div, rem;
        return DivMod(lhs, rhs, div, rem), rem;
    }

    uint128_t& operator/=(uint128_t &lhs, uint128_t rhs) { return lhs = lhs / rhs; }
    uint128_t& operator%=(uint128_t &lhs, uint128_t rhs) { return lhs = lhs % rhs; }
}

using bigint = Codeforces::uint128_t;

// Reads one whitespace-delimited decimal token into v. v is reset to 0
// first, so a previously used variable does not leak its old value into the
// result. Assumes the token contains only the digits 0-9 (no sign). Values
// that do not fit in 128 bits wrap modulo 2^128.
istream& operator>>(istream& in, bigint &v){
    string s;
    in >> s;
    v = 0;
    for (char c : s) v = v * 10 + (c - '0');
    return in;
}

// Writes v in decimal. Zero is handled explicitly, because the digit loop
// alone would print nothing for it.
ostream& operator<<(ostream& out, bigint v){
    if (v == 0) return out << '0';
    string s;
    while (v > 0) {
        s += char('0' + (v % 10).low);
        v /= 10;
    }
    reverse(s.begin(), s.end());
    return out << s;
}



// Solves one test case of Codeforces 1656H "Equal LCM Subsets": choose
// non-empty subsets of a and b whose LCMs are equal, or report that none
// exist. Prints "NO", or "YES" followed by the two subset sizes, the chosen
// values of a on one line and the chosen values of b on the next.
//
// Outline of the code below:
//   1. Build a pairwise-coprime basis of all inputs (PrimeBasis) and write
//      every number as an exponent vector over that basis.
//   2. Start with every element kept. For each basis factor f, compare the
//      largest exponent of f among kept elements of a and of b. The side whose
//      top element is strictly larger loses that element, since the other
//      side can no longer reach that exponent. Repeat until the tops match.
//   3. When the tops are equal, factor f is settled on those two elements and
//      recorded in a_dep / b_dep. If either element is deleted later, f is
//      queued again because its maximum may have dropped.
//   4. If either side runs out of elements for some factor, the answer is NO.
//      Otherwise the kept elements are printed; these are exactly the ones
//      with no exponent above the final per-factor maximum `lcm`.
void solve() {
	int n, m;
	cin >> n >> m;

	vector<bigint> a(n), b(m);
	for (auto& x : a) cin >> x;
	for (auto& x : b) cin >> x;

	// Both subsets must be non-empty, which is impossible if a side is empty.
	if (n == 0 || m == 0) {
		cout << "NO\n";
		return;
	}

	PrimeBasis<bigint> pb;
	for (const auto& x : a) pb.add_element(x);
	for (const auto& x : b) pb.add_element(x);

	// a_fact[i][f] = exponent of basis element f in a[i]; same for b_fact.
	vector<vector<int>> a_fact(n), b_fact(m);
	for (int i = 0; i < n; ++i) a_fact[i] = pb.factorisation(a[i]);
	for (int i = 0; i < m; ++i) b_fact[i] = pb.factorisation(b[i]);

	trace(a_fact);

	// Read the size from the basis itself rather than from a_fact[0].
	const int factors = (int)pb.basis.size();

	vector<int> to_sort_a(n), to_sort_b(m);
	iota(to_sort_a.begin(), to_sort_a.end(), 0);
	iota(to_sort_b.begin(), to_sort_b.end(), 0);

	trace(to_sort_a);

	// sorted_a[f] / sorted_b[f]: element indices ordered by exponent of factor
	// f, largest first.
	vector<vector<int>> sorted_a(factors), sorted_b(factors);
	for (int f = 0; f < factors; ++f) {
		sort(to_sort_a.begin(), to_sort_a.end(), [&](int x, int y) {
			return a_fact[x][f] > a_fact[y][f];
		});
		sorted_a[f] = to_sort_a;
		sort(to_sort_b.begin(), to_sort_b.end(), [&](int x, int y) {
			return b_fact[x][f] > b_fact[y][f];
		});
		sorted_b[f] = to_sort_b;
	}

	vector<bool> del_a(n), del_b(m);               // deleted elements
	vector<int> lcm(factors);                      // settled max exponent per factor
	vector<int> a_ptr(factors), b_ptr(factors);    // cursor into sorted_a / sorted_b
	vector<vector<int>> a_dep(n), b_dep(m);        // factors settled on each element

	// Work stack of factors to (re)check; initially all, popped from the back.
	vector<int> to_check(factors);
	iota(to_check.begin(), to_check.end(), 0);

	while (!to_check.empty()) {
		const int f = to_check.back();
		to_check.pop_back();
		while (true) {
			// Skip elements deleted since this factor was last examined.
			while (a_ptr[f] < n && del_a[sorted_a[f][a_ptr[f]]]) ++a_ptr[f];
			while (b_ptr[f] < m && del_b[sorted_b[f][b_ptr[f]]]) ++b_ptr[f];

			if (a_ptr[f] >= n || b_ptr[f] >= m) {
				cout << "NO\n";
				return;
			}

			const int top_a = sorted_a[f][a_ptr[f]];
			const int top_b = sorted_b[f][b_ptr[f]];
			const int a_exp = a_fact[top_a][f];
			const int b_exp = b_fact[top_b][f];

			if (a_exp > b_exp) {
				del_a[top_a] = true;
				// Factors settled on the deleted element must be re-checked.
				for (int g : a_dep[top_a]) to_check.push_back(g);
				a_dep[top_a].clear();
				++a_ptr[f];
			} else if (a_exp < b_exp) {
				del_b[top_b] = true;
				for (int g : b_dep[top_b]) to_check.push_back(g);
				b_dep[top_b].clear();
				++b_ptr[f];
			} else {
				a_dep[top_a].push_back(f);
				b_dep[top_b].push_back(f);
				lcm[f] = a_exp;
				break;
			}
		}
	}

	trace(lcm);

	cout << "YES\n";

	// True when no exponent of the element exceeds the settled maximum.
	auto fits = [&](const vector<int>& fact) {
		for (int f = 0; f < factors; ++f) {
			if (fact[f] > lcm[f]) return false;
		}
		return true;
	};

	vector<int> a_ans, b_ans;
	for (int i = 0; i < n; ++i) if (fits(a_fact[i])) a_ans.push_back(i);
	for (int i = 0; i < m; ++i) if (fits(b_fact[i])) b_ans.push_back(i);

	cout << a_ans.size() << ' ' << b_ans.size() << '\n';
	for (int x : a_ans) cout << a[x] << ' ';
	cout << '\n';
	for (int x : b_ans) cout << b[x] << ' ';
	cout << '\n';
}

 

// Entry point. Unties and desynchronises iostreams for speed, reads the
// number of test cases (default 1 if nothing can be read) and runs solve()
// once per case.
int main() {
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);

	int t = 1;
	cin >> t;
	for (int remaining = t; remaining > 0; --remaining) solve();

	return 0;
}


Comments

Submit
0 Comments
More Questions

1581A - CQXYM Count Permutations
337A - Puzzles
495A - Digital Counter
796A - Buying A House
67A - Partial Teacher
116A - Tram
1472B - Fair Division
1281C - Cut and Paste
141A - Amusing Joke
112A - Petya and Strings
677A - Vanya and Fence
1621A - Stable Arrangement of Rooks
472A - Design Tutorial Learn from Math
1368A - C+=
450A - Jzzhu and Children
546A - Soldier and Bananas
32B - Borze
1651B - Prove Him Wrong
381A - Sereja and Dima
41A - Translation
1559A - Mocha and Math
832A - Sasha and Sticks
292B - Network Topology
1339A - Filling Diamonds
910A - The Way to Home
617A - Elephant
48A - Rock-paper-scissors
294A - Shaass and Oskols
1213A - Chips Moving
490A - Team Olympiad